-- drop function if exists sm_sc.fv_d_aggr_slice_max_dloss_dindepdt(anyarray, anyarray, anyarray, int[]);
-- Backward pass of sm_sc.fv_aggr_slice_max: gradient of the loss w.r.t. i_indepdt.
--
-- i_indepdt is tiled into contiguous, non-overlapping blocks whose size per dim is
-- given by i_cnt_per_grp; block (y[, x[, x3[, x4]]]) corresponds to element
-- i_depdt[y][x]... / i_dloss_ddepdt[y][x]...  Within each block, every element equal
-- to the block's max (i_depdt) receives that block's upstream gradient
-- (i_dloss_ddepdt) and every other element receives 0.
-- (`==\``, `*\``, `/\``, `%\`` are the project's element-wise array operators —
-- assumed element-wise; confirm against their definitions.)
--
-- Parameters:
--   i_indepdt       forward input; 1 to 4 dims are supported.
--   i_depdt         forward output (block maxima). When null it is recomputed with
--                   sm_sc.fv_aggr_slice_max(i_indepdt, i_cnt_per_grp).
--   i_dloss_ddepdt  upstream gradient d(loss)/d(i_depdt).
--   i_cnt_per_grp   block size per dim. If it has fewer entries than i_indepdt has
--                   dims, the missing LEADING entries are filled with the full
--                   length of those dims of i_indepdt; if null, the whole array is
--                   treated as a single block.
-- Returns:
--   null when i_dloss_ddepdt is null; otherwise an array whose per-dim lengths are
--   (lengths of i_dloss_ddepdt) *` i_cnt_per_grp, built as float[] and returned
--   as the function's polymorphic result type.
-- Raises:
--   when the (padded) i_cnt_per_grp has more than 4 entries.  Further shape
--   checks run only when the setting pg4ml._v_is_debug_check = '1'.
-- NOTE(review): with ties, every element equal to the block max receives the FULL
--   upstream gradient, so a block with k tied maxima propagates k times the
--   gradient.  Confirm this is intended (vs. routing to one element or 1/k each).
create or replace function sm_sc.fv_d_aggr_slice_max_dloss_dindepdt
(
  i_indepdt        anyarray,
  i_depdt          anyarray,
  i_dloss_ddepdt   anyarray,
  i_cnt_per_grp    int[]
)
returns anyarray
as
$$
declare   
  v_dloss_ddepdt_len_y    int         := array_length(i_dloss_ddepdt, 1);
  v_dloss_ddepdt_len_x    int         := array_length(i_dloss_ddepdt, 2);
  v_dloss_ddepdt_len_x3   int         := array_length(i_dloss_ddepdt, 3);
  v_dloss_ddepdt_len_x4   int         := array_length(i_dloss_ddepdt, 4);
  v_ret                   float[]     ;
  -- note: loop counters v_cur_y / v_cur_x / v_cur_x3 / v_cur_x4 are implicitly
  -- declared by the integer FOR loops below, so they are not declared here.
  v_indepdt_len           int[]       := 
    (
      select 
        array_agg(array_length(i_indepdt, a_cur_dim) order by a_cur_dim) 
      from generate_series(1, array_ndims(i_indepdt)) tb_a_cur_dim(a_cur_dim)
    )
  ;
begin
  -- Pad i_cnt_per_grp on the left with the full lengths of the leading dims of
  -- i_indepdt so it has one entry per dim; fall back to v_indepdt_len when null.
  i_cnt_per_grp := 
    sm_sc.fv_coalesce
    (
      v_indepdt_len[ : array_length(v_indepdt_len, 1) - coalesce(array_length(i_cnt_per_grp, 1), 0)] || i_cnt_per_grp
    , v_indepdt_len
    )
  ;
  if current_setting('pg4ml._v_is_debug_check', true) = '1'
  then
    if array_ndims(i_cnt_per_grp) > 1 
    then 
      raise exception 'unsupport ndims of i_cnt_per_grp > 1.';
    elsif array_ndims(i_dloss_ddepdt) <> array_length(i_cnt_per_grp, 1)
    then 
      raise exception 'unmatch between ndims of i_dloss_ddepdt and length of i_cnt_per_grp.';
    elsif array_dims(i_depdt) <> array_dims(i_dloss_ddepdt)
    then 
      raise exception 'unmatch between dims of i_depdt and i_dloss_ddepdt.';
    elsif 
      -- every dim of i_indepdt must be an exact multiple of its block size
      0 <> any 
      (
        v_indepdt_len
        %` i_cnt_per_grp
      )
    then 
      raise exception 'unperfect i_indepdt''s length for i_cnt_per_grp at some dims';
    elsif 
      (
        select 
          array_agg(array_length(i_dloss_ddepdt, a_cur_dim) order by a_cur_dim) 
        from generate_series(1, array_ndims(i_dloss_ddepdt)) tb_a_cur_dim(a_cur_dim)
      )
      <>  
      (
        v_indepdt_len
        /` i_cnt_per_grp
      )
    then 
      raise exception 'unmatch length between i_dloss_ddepdt and i_indepdt devided by i_cnt_per_grp';
    end if;
  end if;
    
  if i_dloss_ddepdt is null
  then 
    return null;
  end if;
  
  -- Recompute the forward maxima only when they are actually needed
  -- (after the null-gradient early return, so no work is wasted).
  if i_depdt is null
  then 
    i_depdt := sm_sc.fv_aggr_slice_max(i_indepdt, i_cnt_per_grp);
  end if;
  
  -- One branch per supported rank: each loop iteration handles one block,
  -- writing (block == block_max) * upstream_gradient into the matching slice.
  if array_length(i_cnt_per_grp, 1) = 1
  then    
    v_ret := array_fill(null :: float, array[v_dloss_ddepdt_len_y] *` i_cnt_per_grp);
    for v_cur_y in 1 .. v_dloss_ddepdt_len_y
    loop 
      v_ret
        [(v_cur_y - 1) * i_cnt_per_grp[1] + 1 : v_cur_y * i_cnt_per_grp[1]]
      :=
        (
          i_indepdt
            [(v_cur_y - 1) * i_cnt_per_grp[1] + 1 : v_cur_y * i_cnt_per_grp[1]] 
          ==` 
          i_depdt
            [v_cur_y]
        ) :: int[] :: float[]
        *` 
        i_dloss_ddepdt
          [v_cur_y]
      ;        
    end loop;
    
  elsif array_length(i_cnt_per_grp, 1) = 2
  then  
    v_ret := array_fill(null :: float, array[v_dloss_ddepdt_len_y, v_dloss_ddepdt_len_x] *` i_cnt_per_grp);  
    for v_cur_y in 1 .. v_dloss_ddepdt_len_y
    loop 
      for v_cur_x in 1 .. v_dloss_ddepdt_len_x
      loop 
        v_ret
          [(v_cur_y - 1) * i_cnt_per_grp[1] + 1 : v_cur_y * i_cnt_per_grp[1]]
          [(v_cur_x - 1) * i_cnt_per_grp[2] + 1 : v_cur_x * i_cnt_per_grp[2]]
        :=
          (
            i_indepdt
              [(v_cur_y - 1) * i_cnt_per_grp[1] + 1 : v_cur_y * i_cnt_per_grp[1]] 
              [(v_cur_x - 1) * i_cnt_per_grp[2] + 1 : v_cur_x * i_cnt_per_grp[2]]
            ==` 
            i_depdt
              [v_cur_y]
              [v_cur_x]
          ) :: int[] :: float[]
          *` 
          i_dloss_ddepdt
            [v_cur_y]
            [v_cur_x]
        ;        
      end loop;
    end loop;
    
  elsif array_length(i_cnt_per_grp, 1) = 3
  then    
    v_ret := array_fill(null :: float, array[v_dloss_ddepdt_len_y, v_dloss_ddepdt_len_x, v_dloss_ddepdt_len_x3] *` i_cnt_per_grp);  
    for v_cur_y in 1 .. v_dloss_ddepdt_len_y
    loop 
      for v_cur_x in 1 .. v_dloss_ddepdt_len_x
      loop 
        for v_cur_x3 in 1 .. v_dloss_ddepdt_len_x3
        loop 
          v_ret
            [(v_cur_y - 1) * i_cnt_per_grp[1] + 1 : v_cur_y * i_cnt_per_grp[1]]
            [(v_cur_x - 1) * i_cnt_per_grp[2] + 1 : v_cur_x * i_cnt_per_grp[2]]
            [(v_cur_x3 - 1) * i_cnt_per_grp[3] + 1 : v_cur_x3 * i_cnt_per_grp[3]]
          :=
            (
              i_indepdt
                [(v_cur_y - 1) * i_cnt_per_grp[1] + 1 : v_cur_y * i_cnt_per_grp[1]] 
                [(v_cur_x - 1) * i_cnt_per_grp[2] + 1 : v_cur_x * i_cnt_per_grp[2]]
                [(v_cur_x3 - 1) * i_cnt_per_grp[3] + 1 : v_cur_x3 * i_cnt_per_grp[3]]
              ==` 
              i_depdt
                [v_cur_y]
                [v_cur_x]
                [v_cur_x3]
            ) :: int[] :: float[]
            *` 
            i_dloss_ddepdt
              [v_cur_y]
              [v_cur_x]
              [v_cur_x3]
          ;        
        end loop;
      end loop;
    end loop;
    
  elsif array_length(i_cnt_per_grp, 1) = 4
  then    
    v_ret := array_fill(null :: float, array[v_dloss_ddepdt_len_y, v_dloss_ddepdt_len_x, v_dloss_ddepdt_len_x3, v_dloss_ddepdt_len_x4] *` i_cnt_per_grp);  
    for v_cur_y in 1 .. v_dloss_ddepdt_len_y
    loop 
      for v_cur_x in 1 .. v_dloss_ddepdt_len_x
      loop 
        for v_cur_x3 in 1 .. v_dloss_ddepdt_len_x3
        loop 
          for v_cur_x4 in 1 .. v_dloss_ddepdt_len_x4
          loop 
            v_ret
              [(v_cur_y - 1) * i_cnt_per_grp[1] + 1 : v_cur_y * i_cnt_per_grp[1]]
              [(v_cur_x - 1) * i_cnt_per_grp[2] + 1 : v_cur_x * i_cnt_per_grp[2]]
              [(v_cur_x3 - 1) * i_cnt_per_grp[3] + 1 : v_cur_x3 * i_cnt_per_grp[3]]
              [(v_cur_x4 - 1) * i_cnt_per_grp[4] + 1 : v_cur_x4 * i_cnt_per_grp[4]]
            :=
              (
                i_indepdt
                  [(v_cur_y - 1) * i_cnt_per_grp[1] + 1 : v_cur_y * i_cnt_per_grp[1]] 
                  [(v_cur_x - 1) * i_cnt_per_grp[2] + 1 : v_cur_x * i_cnt_per_grp[2]]
                  [(v_cur_x3 - 1) * i_cnt_per_grp[3] + 1 : v_cur_x3 * i_cnt_per_grp[3]]
                  [(v_cur_x4 - 1) * i_cnt_per_grp[4] + 1 : v_cur_x4 * i_cnt_per_grp[4]]
                ==` 
                i_depdt
                  [v_cur_y]
                  [v_cur_x]
                  [v_cur_x3]
                  [v_cur_x4]
              ) :: int[] :: float[]
              *` 
              i_dloss_ddepdt
                [v_cur_y]
                [v_cur_x]
                [v_cur_x3]
                [v_cur_x4]
            ;        
          end loop;
        end loop;
      end loop;
    end loop;
    
  elsif array_length(i_cnt_per_grp, 1) > 4
  then
    -- Previously this case fell through and silently returned null.
    raise exception 'unsupport length of i_cnt_per_grp > 4.';
    
  end if;
  
  return v_ret;
end
$$
language plpgsql stable
parallel safe
cost 100;

-- with 
-- cte_rand as 
-- (
--   select 
--     sm_sc.fv_new_rand(array[6]) as a_indepdt
-- )
-- select 
--   sm_sc.fv_d_aggr_slice_max_dloss_dindepdt
--   (
--     a_indepdt
--   , sm_sc.fv_aggr_slice_max(a_indepdt, array[3])
--   , sm_sc.fv_new_rand(array[2])
--   , array[3]
--   ) :: decimal[] ~=` 6
-- from cte_rand

-- with 
-- cte_rand as 
-- (
--   select 
--     sm_sc.fv_new_rand(array[6, 8]) as a_indepdt
-- )
-- select 
--   sm_sc.fv_d_aggr_slice_max_dloss_dindepdt
--   (
--     a_indepdt
--   , sm_sc.fv_aggr_slice_max(a_indepdt, array[3, 2])
--   , sm_sc.fv_new_rand(array[2, 4])
--   , array[3, 2]
--   ) :: decimal[] ~=` 6
-- from cte_rand

-- with 
-- cte_rand as 
-- (
--   select 
--     sm_sc.fv_new_rand(array[6, 8, 8]) as a_indepdt
-- )
-- select 
--   sm_sc.fv_d_aggr_slice_max_dloss_dindepdt
--   (
--     a_indepdt
--   , sm_sc.fv_aggr_slice_max(a_indepdt, array[3, 2, 2])
--   , sm_sc.fv_new_rand(array[2, 4, 4])
--   , array[3, 2, 2]
--   ) :: decimal[] ~=` 6
-- from cte_rand

-- with 
-- cte_rand as 
-- (
--   select 
--     sm_sc.fv_new_rand(array[6, 8, 8, 6]) as a_indepdt
-- )
-- select 
--   sm_sc.fv_d_aggr_slice_max_dloss_dindepdt
--   (
--     a_indepdt
--   , sm_sc.fv_aggr_slice_max(a_indepdt, array[3, 2, 2, 3])
--   , sm_sc.fv_new_rand(array[2, 4, 4, 2])
--   , array[3, 2, 2, 3]
--   ) :: decimal[] ~=` 6
-- from cte_rand